from typing import TYPE_CHECKING
from os.path import join
import logging
import tempfile
import shutil
from functools import lru_cache

import click
import numpy as np

from rastervision.pipeline.pipeline import Pipeline
from rastervision.core.box import Box
from rastervision.core.data_sample import DataSample
from rastervision.core.data import Scene, Labels
from rastervision.core.backend import Backend
from rastervision.pipeline.file_system.utils import (
    download_if_needed, zipdir, get_local_path, upload_or_copy, make_dir,
    sync_from_dir, file_exists)

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Needed for annotations only; never imported at runtime.
    from rastervision.core.rv_pipeline import RVPipelineConfig

# Every command an RVPipeline knows about. RVPipeline.commands starts from a
# copy of this list, so the order here is the order that property reports.
ALL_COMMANDS = [
    'analyze',
    'chip',
    'train',
    'predict',
    'eval',
    'bundle',
]

# Candidates for RVPipeline.split_commands (the corresponding methods take
# split_ind / num_splits arguments).
SPLITTABLE_COMMANDS = [
    'chip',
    'predict',
]

# Candidates for RVPipeline.gpu_commands.
GPU_COMMANDS = [
    'train',
    'predict',
]


class RVPipeline(Pipeline):
    """Base class of all Raster Vision Pipelines.

    This can be subclassed to implement Pipelines for different computer vision
    tasks over geospatial imagery. The commands and what they produce include:

    - analyze: metrics on the imagery and labels
    - chip: small training and validation images taken from larger scenes
    - train: model trained on chips
    - predict: predictions over entire validation and test scenes
    - eval: evaluation metrics for predictions generated by model
    - bundle: bundle containing model and any other files needed to make
      predictions using the Predictor.
    """

    def __init__(self, config: 'RVPipelineConfig', tmp_dir: str):
        """Constructor.

        Args:
            config: Configuration of this pipeline.
            tmp_dir: Root temporary directory; forwarded to
                ``Pipeline.__init__``.
        """
        super().__init__(config, tmp_dir)
        # Stays None until build_backend() is called.
        self.backend: 'Backend | None' = None
        self.config: 'RVPipelineConfig'

    @property
    def commands(self):
        """Commands to run for this pipeline.

        Starts from a copy of ``ALL_COMMANDS``, drops ``'analyze'`` when no
        analyzers are configured, and then lets the backend config filter the
        result via ``filter_commands()``.

        Note that a notice is printed each time ``'analyze'`` is dropped, i.e.
        on every access of this property when there are no analyzers.
        """
        commands = ALL_COMMANDS[:]
        if len(self.config.analyzers) == 0 and 'analyze' in commands:
            commands.remove('analyze')
            click.secho("Skipping 'analyze' command...", fg='green', bold=True)
        commands = self.config.backend.filter_commands(commands)
        return commands

    @property
    def split_commands(self):
        """``SPLITTABLE_COMMANDS`` as filtered by the backend config."""
        return self.config.backend.filter_commands(SPLITTABLE_COMMANDS)

    @property
    def gpu_commands(self):
        """``GPU_COMMANDS`` as filtered by the backend config."""
        return self.config.backend.filter_commands(GPU_COMMANDS)

    def _make_scene_builder(self, use_transformers: bool):
        """Return a memoized function that builds a Scene from a scene ID.

        The returned function looks up the scene config by ID among
        ``dataset.all_scenes`` and builds it. Results are cached with an LRU
        cache sized to the number of scenes, so a scene that appears in
        several groups (or is requested repeatedly) is built only once per
        builder.

        Args:
            use_transformers: Forwarded to each scene config's ``build()``.
        """
        dataset = self.config.dataset
        class_config = dataset.class_config
        all_scenes = dataset.all_scenes
        scene_id_to_cfg = {s.id: s for s in all_scenes}

        @lru_cache(maxsize=len(all_scenes))
        def build_scene(scene_id: str) -> Scene:
            scene_cfg = scene_id_to_cfg[scene_id]
            return scene_cfg.build(
                class_config, self.tmp_dir, use_transformers=use_transformers)

        return build_scene

    def analyze(self):
        """Run each analyzer over each non-empty scene group.

        Scenes are built lazily (without transformers) and cached, so they
        are shared between analyzers and between groups.
        """
        dataset = self.config.dataset
        build_scene = self._make_scene_builder(use_transformers=False)

        # build and run each AnalyzerConfig for each scene group
        for analyzer_cfg in self.config.analyzers:
            for group_name, group_ids in dataset.scene_groups.items():
                if len(group_ids) == 0:
                    log.info(f'Skipping scene group "{group_name}". '
                             'Empty scene group.')
                    continue
                # build() and process() each get their own lazy iterator: a
                # single generator shared by both would reach process()
                # exhausted if build() ever iterated over it. Thanks to the
                # cache, iterating twice does not build any scene twice.
                analyzer = analyzer_cfg.build(scene_group=(
                    group_name,
                    (build_scene(scene_id) for scene_id in group_ids)))

                log.info(f'Running {type(analyzer).__name__} on '
                         f'scene group "{group_name}"...')
                analyzer.process(
                    (build_scene(scene_id) for scene_id in group_ids),
                    self.tmp_dir)

    def get_train_windows(self, scene: Scene) -> list[Box]:
        """Return the training windows for a Scene.

        Each training window represents the spatial extent of a training chip to
        generate.

        Not implemented in this base class.

        Args:
            scene: Scene to generate windows for

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def get_train_labels(self, window: Box, scene: Scene) -> Labels:
        """Return the training labels in a window for a scene.

        Not implemented in this base class.

        Args:
            window: Window to get labels for.
            scene: Scene the window belongs to.

        Returns:
            Labels that lie within window

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def chip(self, split_ind: int = 0, num_splits: int = 1):
        """Save training and validation chips.

        Does nothing (and does not build the backend) if this split contains
        neither training nor validation scenes.

        Args:
            split_ind: Index of the dataset split to process.
            num_splits: Total number of splits.
        """
        cfg = self.config
        log.info(f'Chip options: {cfg.chip_options}')
        dataset = cfg.dataset.get_split_config(split_ind, num_splits)
        if not dataset.train_scenes and not dataset.validation_scenes:
            return
        backend = cfg.backend.build(cfg, self.tmp_dir)
        backend.chip_dataset(dataset, cfg.chip_options)

    def train(self):
        """Train a model and save it.

        Builds a fresh backend and calls its ``train()`` with the configured
        ``source_bundle_uri``.
        """
        backend = self.config.backend.build(self.config, self.tmp_dir)
        backend.train(source_bundle_uri=self.config.source_bundle_uri)

    def post_process_sample(self, sample: DataSample) -> DataSample:
        """Post-process sample in pipeline-specific way.

        This should be called before writing a sample during chipping.
        The base implementation returns the sample unchanged.
        """
        return sample

    def post_process_batch(self, windows: list[Box], chips: np.ndarray,
                           labels: Labels) -> Labels:
        """Post-process a batch of predictions.

        The base implementation returns ``labels`` unchanged.
        """
        return labels

    def post_process_predictions(self, labels: Labels, scene: Scene) -> Labels:
        """Post-process all labels at end of prediction.

        The base implementation returns ``labels`` unchanged.
        """
        return labels

    def predict(self, split_ind: int = 0, num_splits: int = 1) -> None:
        """Make predictions over each validation and test scene.

        This uses a sliding window.

        The backend is built on first use and then kept on ``self.backend``.
        Predictions for each scene are saved to that scene's label store.

        Args:
            split_ind: Index of the dataset split to process.
            num_splits: Total number of splits.
        """
        if self.backend is None:
            self.build_backend()

        class_config = self.config.dataset.class_config
        dataset = self.config.dataset.get_split_config(split_ind, num_splits)

        for scene_config in (dataset.validation_scenes + dataset.test_scenes):
            scene = scene_config.build(class_config, self.tmp_dir)
            labels = self.predict_scene(scene)
            scene.label_store.save(labels)

    def predict_scene(self, scene: Scene) -> Labels:
        """Predict on a single scene and post-process the result.

        Builds the backend first if it has not been built yet.

        Args:
            scene: Scene to make predictions for.

        Returns:
            The backend's predictions for the scene after being passed
            through ``post_process_predictions()``.
        """
        if self.backend is None:
            self.build_backend()
        labels = self.backend.predict_scene(
            scene, predict_options=self.config.predict_options)
        labels = self.post_process_predictions(labels, scene)
        return labels

    def eval(self):
        """Evaluate predictions against ground truth.

        Runs each evaluator over each non-empty scene group other than
        ``'train_scenes'``. If processing a group raises
        ``FileNotFoundError``, a warning is logged and the group is skipped.
        """
        dataset = self.config.dataset
        class_config = dataset.class_config
        # it might make sense to make excluded_groups a field in an EvalConfig
        # in the future
        excluded_groups = ['train_scenes']

        build_scene = self._make_scene_builder(use_transformers=True)

        # build and run each EvaluatorConfig for each scene group
        for evaluator_cfg in self.config.evaluators:
            for group_name, group_ids in dataset.scene_groups.items():
                if group_name in excluded_groups:
                    continue
                if len(group_ids) == 0:
                    log.info(f'Skipping scene group "{group_name}". '
                             'Empty scene group.')
                    continue
                # build() and process() each get their own lazy iterator: a
                # single generator shared by both would reach process()
                # exhausted if build() ever iterated over it. Thanks to the
                # cache, iterating twice does not build any scene twice.
                evaluator = evaluator_cfg.build(
                    class_config,
                    scene_group=(
                        group_name,
                        (build_scene(scene_id) for scene_id in group_ids)))

                log.info(f'Running {type(evaluator).__name__} on '
                         f'scene group "{group_name}"...')
                try:
                    evaluator.process(
                        (build_scene(scene_id) for scene_id in group_ids),
                        self.tmp_dir)
                except FileNotFoundError:
                    log.warning(f'Skipping scene group "{group_name}". '
                                'Either labels or predictions are missing for '
                                'some scene.')

    def bundle(self):
        """Save a model bundle with whatever is needed to make predictions.

        The model bundle is a zip file and it is used by the Predictor and
        predict CLI subcommand.

        The bundle contains the files named by the backend config's
        ``get_bundle_filenames()`` (fetched from ``train_uri``), the contents
        of ``analyze_uri`` under ``analyze/`` if that URI exists, and the
        pipeline config as ``pipeline-config.json``.
        """
        cfg = self.config
        # Staging happens in a throwaway sub-directory of self.tmp_dir that
        # is removed when the block exits.
        with tempfile.TemporaryDirectory(dir=self.tmp_dir) as work_dir:
            bundle_dir = join(work_dir, 'bundle')
            make_dir(bundle_dir)

            for fn in cfg.backend.get_bundle_filenames():
                path = download_if_needed(join(cfg.train_uri, fn), work_dir)
                shutil.copy(path, join(bundle_dir, fn))

            if file_exists(cfg.analyze_uri, include_dir=True):
                analyze_dst = join(bundle_dir, 'analyze')
                sync_from_dir(cfg.analyze_uri, analyze_dst)

            path = download_if_needed(cfg.get_config_uri(), work_dir)
            shutil.copy(path, join(bundle_dir, 'pipeline-config.json'))

            model_bundle_uri = cfg.get_model_bundle_uri()
            # The local zip path is resolved against self.tmp_dir, not the
            # throwaway staging directory.
            model_bundle_path = get_local_path(model_bundle_uri, self.tmp_dir)
            zipdir(bundle_dir, model_bundle_path)
            upload_or_copy(model_bundle_path, model_bundle_uri)

    def build_backend(self, uri: str | None = None) -> None:
        """Build the backend, load a model into it, and keep it on ``self``.

        Replaces any backend previously stored on ``self.backend``.

        Args:
            uri: Passed through to the backend's ``load_model()``.
                NOTE(review): presumably ``None`` makes the backend fall back
                to a default model location -- confirm against the Backend
                implementation.
        """
        self.backend = self.config.backend.build(self.config, self.tmp_dir)
        self.backend.load_model(uri)
